#!/usr/bin/env python3

# Python script to run and analyse MMS test
#
# Tests a range of different schemes
#

#requires: all_tests
#requires: make

from boututils.run_wrapper import shell, shell_safe, launch_safe
from boutdata.collect import collect

from numpy import sqrt, max, abs, mean, array, log, pi

from os.path import join

import pickle

import time



if __name__ == "__main__":
    # Only when this file is executed directly (not imported): build the
    # test in the current directory, sending make's stdout to make.log.
    build_command = "make > make.log"
    print("Making MMS advection test")
    shell_safe(build_command)

def make_parent():
    """Build the MMS advection test in the parent directory.

    Prints a progress message, then runs ``make -C ..`` through
    ``shell_safe`` with make's standard output redirected to ``make.log``.
    """
    build_command = "make -C .. > make.log"
    print("Making MMS advection test")
    shell_safe(build_command)

# Resolutions scanned by run_mms, coarsest first. Each value sets MZ and,
# with 4 added, mesh:nx. Successive values double.
nxlist = [16, 32, 64, 128, 256, 512, 1024]

# Number of processes passed to launch_safe for every run.
nproc = 8

# Reference timestep; not passed to the solver by run_mms as written.
dt0 = 0.15

def run_mms(options,exit=True):
    """Run the MMS advection test for each option set and check convergence.

    For every ``(opts, exp_order)`` pair in ``options`` the ``./advection``
    executable is launched once per resolution in the module-level
    ``nxlist``. After each run the error field ``E_f`` is collected from the
    ``data`` directory and its RMS (L2) and maximum (L-infinity) norms are
    recorded. The convergence order of the L2 error between the last two
    resolutions is then compared with ``exp_order``.

    Parameters
    ----------
    options : iterable of (str, number)
        Pairs of a command-line option string for ``./advection`` and the
        convergence order expected with those options.
    exit : bool, optional
        If true (the default), hand the overall result to ``_exit`` once all
        option sets have been processed instead of just returning it.

    Returns
    -------
    bool
        True if every measured order satisfied
        ``exp_order - 0.25 < order < exp_order + 0.75``, False otherwise.

    Raises
    ------
    ValueError
        If ``options`` is non-empty and ``nxlist`` holds fewer than two
        resolutions, since no convergence order can be estimated.
    """
    success = True

    for opts,exp_order in options:
        # Check up front so we fail with a clear message before spending
        # time on simulations whose results could not be analysed.
        if len(nxlist) < 2:
            raise ValueError(
                "nxlist must contain at least two resolutions "
                "to estimate a convergence order"
            )

        error_2   = []  # The L2 error (RMS)
        error_inf = []  # The maximum error
        prev_nx = None  # Resolution of the previous run, if any

        for nx in nxlist:
            # Set the X and Z mesh size

            dx = 2.*pi / (nx)

            # NOTE(review): the +4 on mesh:nx presumably accounts for X guard
            # cells (collect is called with xguards=False) - confirm.
            args = opts + " mesh:nx="+str(nx+4)+" mesh:dx="+str(dx)+" MZ="+str(nx) #+" solver:timestep="+str(dt0/nx)

            print("  Running with " + args)

            # Delete old data
            shell("rm data/BOUT.dmp.*.nc")

            # Command to run
            cmd = "./advection "+args
            # Launch using MPI
            start_time_ = time.time()

            # The returned status is not inspected here.
            _, out = launch_safe(cmd, nproc=nproc, pipe=True)

            # Save output to log file; the context manager closes the file
            # even if the write fails.
            with open("run.log."+str(nx), "w") as f:
                f.write(out)

            # Collect data
            E_f = collect("E_f", xguards=False, tind=[1,1], info=False, path="data")

            # Average error over domain
            l2 = sqrt(mean(E_f**2))
            linf = max(abs( E_f ))

            error_2.append( l2 )
            error_inf.append( linf )

            if prev_nx is not None:
                # Order between this and the previous resolution. The grid
                # spacing scales as 1/nx, so the spacing ratio is prev_nx/nx
                # (exactly 0.5 when the resolution doubles).
                order = log(error_2[-1] / error_2[-2]) / log(float(prev_nx) / nx)
                print(" %d | %7.3f s | %.6f"%(nx,time.time() - start_time_,order))
            else:
                print(" %d | %7.3f s | ?"%(nx,time.time() - start_time_))

            print("  -> Error norm: l-2 %f l-inf %f" % (l2, linf))

            prev_nx = nx

        dx = 1. / array(nxlist)

        # Calculate convergence order from the two finest resolutions

        order = log(error_2[-1] / error_2[-2]) / log(dx[-1] / dx[-2])
        print("Convergence order = %f" % (order))

        # The accepted window is asymmetric: converging somewhat faster than
        # expected passes, converging noticeably slower fails.
        if exp_order-.25 < order < exp_order+0.75:
            print("............ PASS")
        else:
            success = False
            print("............ FAIL")


    if exit:
        _exit(success)
    return success

def _exit(success):
    """Print the overall test verdict and terminate the interpreter.

    Exits with status 0 when ``success`` is truthy and status 1 otherwise.
    """
    import sys

    if success:
        verdict, status = " => All tests passed", 0
    else:
        verdict, status = " => Some failed tests", 1

    print(verdict)
    sys.exit(status)
